import os
import re
import json
import random
import torch
from rich import print
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from policies import BasePolicy
from policies.utils.replay_buffer import ReplayBuffer
from policies.utils.reflection import Reflection
import policies.prompts as prompts

from policies.base_policy import BasePolicy
from policies.utils.vector_replay_buffer import VectorReplayBuffer
from policies.utils.reflection import Reflection, get_obs_message

# Hugging Face model identifiers. This list is not referenced anywhere else in
# the visible part of this file.
# NOTE(review): the DiLUPolicy default model (Meta-Llama-3-8B-Instruct) is not
# listed here -- confirm whether this list is meant to constrain `model`.
MODELS = [
    "meta-llama/Meta-Llama-3-70B-Instruct",
]

class DiLUPolicy(BasePolicy):
    """
    Few-shot prompting policy. Past experiences are stored in a vector memory,
    and examples retrieved for the current observation are given to the
    language model as in-context demonstrations before it picks an action.
    """
    def __init__(self,
                 model="meta-llama/Meta-Llama-3-8B-Instruct",
                 agent_id="",
                 temperature=0.2,
                 adapter=None,
                 device="cuda",
                 comm_only=False,
                 control_only=False,
                 skip_frames=0,
                 vector_memory_database="vector_memory",
                 num_few_shot_examples=3,
                 batch_size=1,
                 is_focal=False
    ):
        """
        Load the language model and set up memory, reflection and counters.

        Args:
            model: identifier passed to ``from_pretrained`` for both the
                causal LM and its tokenizer.
            agent_id: identifier stored on the policy.
            temperature: sampling temperature forwarded to ``generate`` and
                to the reflection module.
            adapter: optional adapter loaded onto the model when not None.
            device: requested torch device; only honoured when CUDA is
                available, otherwise the policy runs on "cpu".
            comm_only: flag forwarded to the prompt builders and reflection.
            control_only: flag forwarded to the decision prompt and reflection.
            skip_frames: number of initial steps for which ``act`` returns
                the default "go" action without querying the model.
            vector_memory_database: directory name (relative to the current
                working directory) used as the vector memory data path.
            num_few_shot_examples: number of examples requested from the
                vector memory per decision.
            batch_size: number of transitions sampled per ``learn`` call.
            is_focal: flag forwarded to the reflection module.
        """
        # basic policy setup
        self.agent_id = agent_id
        self.decision_frequency = 10 # frames, decision_frequency / frame_rate = seconds per decision
        self.comm_only = comm_only # indicate whether the agent is communication only
        self.control_only = control_only # indicate whether the agent is control only
        self.skip_frames = skip_frames
        self.is_focal = is_focal

        # set up language model
        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch.bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        if adapter is not None:
            self.model.load_adapter(adapter)
        # generation stops on either the tokenizer's EOS token or "<|eot_id|>"
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        # the requested device is only used when CUDA is available
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        self.temperature = temperature

        # basic prompts
        self.instruction = prompts.get_instruction(comm_only)
        self.common_sense = prompts.get_common_sense()
        self.message_history = []

        # set up data base for storing few-shot examples
        self.vector_memory_database = vector_memory_database
        self.vector_memory = VectorReplayBuffer(
            data_path=os.path.join(os.getcwd(), vector_memory_database)
        )

        # set up learning module
        self.replay_buffer = ReplayBuffer()
        self.batch_size = batch_size
        self.iteration = 0
        self.experience = None
        self.current_observation = None
        # reset() assigns this attribute; create it here as well so it exists
        # on a freshly constructed policy that has not been reset yet
        self.learned_knowledge = ""

        # reflection for self-correction
        self.reflection = Reflection(model=self.model,
                                     tokenizer=self.tokenizer,
                                     device=self.device,
                                     temperature=self.temperature,
                                     comm_only=self.comm_only,
                                     control_only=self.control_only,
                                     is_focal=self.is_focal
                                     )

        # policy specific setup
        self.num_few_shot_examples = num_few_shot_examples

        # episodic metric setup
        self.step_count = 0
        self.episode_return = 0
        self.prev_action = None
        self.plan = None

    def reset(self):
        """
        Reset the agent when a new episode starts
        """
        self.step_count = 0
        self.episode_return = 0
        self.current_observation = None
        self.prev_action = None
        self.plan = None
        self.message_history = []
        self.iteration = 0
        self.learned_knowledge = ""
        print("Resetting the agent")

    def observe(self, obs, reward, terminated, truncated, info):
        self.current_observation = obs
        self.episode_return += reward

    def act(self):
        self.step_count += 1
        if self.step_count <= self.skip_frames:
            return {"command":"go", "message":""}
        if self.step_count % self.decision_frequency == 1:
            # 1. retrive few shot examples from the memory
            few_shot_obs, few_shot_reasoning, few_shot_actions = self.retrieve_few_shot_obs()
            # 2. prompt with few shot examples
            response = self.prompting(few_shot_obs, few_shot_reasoning, few_shot_actions)
            # 3. parse the action from the response
            action = self.parse_action(response)
            self.prev_action = action

        action = self.prev_action
        return action

    def prompting(self, few_shot_obs=None, few_shot_reasoning=None, few_shot_actions=None)->dict:
        """
        Generate the chain of thought prompting

        Rebuilds ``self.message_history`` from scratch (system messages,
        few-shot examples, current observation), then queries the model twice:
        once for the reasoning and once for the final decision.

        Args:
            few_shot_obs: observations of the retrieved examples (default: none).
            few_shot_reasoning: reasoning texts of the retrieved examples.
            few_shot_actions: actions of the retrieved examples.

        Returns:
            dict: with keys "prompt" (the reasoning prompt), "reasoning" (the
            model's reasoning) and "action" (the model's raw decision text).
        """
        # None instead of mutable [] defaults so no list is shared across calls
        few_shot_obs = [] if few_shot_obs is None else few_shot_obs
        few_shot_reasoning = [] if few_shot_reasoning is None else few_shot_reasoning
        few_shot_actions = [] if few_shot_actions is None else few_shot_actions

        response = {}

        # System message step
        self.message_history = [{"role":"system", "content":self.instruction}]
        self.message_history.append({"role":"system", "content":self.common_sense})

        # In-context few-shot prompting
        few_shot_prompt = prompts.get_few_shot_prompt(few_shot_obs, few_shot_reasoning, few_shot_actions)
        self.message_history.extend(few_shot_prompt)

        # Prompt observation
        observation = self.current_observation
        if observation is None:
            obs_message = "No observation"
        else:
            obs_message = get_obs_message(observation)
        self.message_history.append({"role":"user", "content": obs_message})

        # Reasoning prompting
        cot_prompt1 = prompts.get_cot_prompt_1()
        response["prompt"] = cot_prompt1
        self.message_history.append({"role":"user", "content":cot_prompt1})
        response["reasoning"] = self.chat(self.message_history)
        self.message_history.append({"role":"assistant", "content":response["reasoning"]})

        # Decision prompting
        cot_prompt2 = prompts.get_cot_prompt_2(self.comm_only, self.control_only)
        self.message_history.append({"role":"user", "content":cot_prompt2})
        response["action"] = self.chat(self.message_history, json_format=True)

        # Record the action generated by the model
        self.message_history.append({"role":"assistant", "content":response["action"]})
        print(self.message_history)
        return response

    def chat(self, prompt, json_format=False):
        prompt = self.tokenizer.apply_chat_template(prompt,
                                                    add_generation_prompt=True,
                                                    return_tensors='pt').to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(prompt,
                                          max_length=8192,
                                          temperature=self.temperature,
                                          eos_token_id=self.terminators
                                          )
        response = self.tokenizer.decode(outputs[0][prompt.shape[-1]:], skip_special_tokens=True)
        return response

    def parse_action(self, response):
        """
        Parse the action into a dictionary
        """
        action = {"command":"go", "reasoning":response["reasoning"]}
        try:
            response = re.findall(r"\{[^*]*\}", response["action"])[0]
            if response:
                response = json.loads(response)
                if "command" in response:
                    action["command"] = response["command"]
                if "message" in response:
                    action["message"] = response["message"]
        except:
            print("Error in parsing the response", response)
        action = dict(sorted(action.items(), key=lambda item: item[0]))
        return action

    def get_episode_return(self)->float:
        return self.episode_return

    def retrieve_few_shot_obs(self):
        if self.num_few_shot_examples == 0:
            return [], [], []
        obs = str(self.current_observation)
        few_shot_results = self.vector_memory.retrive(state=obs,
                                                      k=self.num_few_shot_examples)
        few_shot_obs = []
        few_shot_reasoning = []
        few_shot_actions = []
        for few_shot_result in few_shot_results:
            few_shot_obs.append(
                few_shot_result["obs"])
            few_shot_reasoning.append(few_shot_result["reasoning"])
            few_shot_actions.append(few_shot_result["action"])
            mode_action = max(
                set(few_shot_actions), key=few_shot_actions.count)
            mode_action_count = few_shot_actions.count(mode_action)            
        return few_shot_obs, few_shot_reasoning, few_shot_actions

    def store_transition(self, transition)->None:
        """
        Store the transition in replay buffer for learning
        """
        if transition.obs is None:
            return
        self.replay_buffer.add(transition)

    def learn(self)->None:
        """
        Learn from the experince
        """
        self.iteration += 1        
        # 1. get environment feedback
        batch = self.replay_buffer.sample_batch(batch_size=self.batch_size)
        for transition in batch:
            comments = "original output"
            # 2. if we get into collision or stagnation, ask model to correct them
            if (transition.feedback["collision_info"]["collision_occurred"] or 
                transition.feedback["stagnation_info"]["stagnation_occurred"]):
                # 3. ask the language model to self-correct
                transition = self.reflection.correct(transition)
                comments = "corrected output"
            # 4. store the sampled (both corrected and raw) memory item back to the memory
            try:
                self.vector_memory.add(
                                    state=str(transition.obs),
                                    obs=str(transition.obs),
                                    reasoning=transition.action.get("reasoning", ""),
                                    action=str({"command":transition.action.get("command", None),
                                            "message":transition.action.get("message", None)}),
                                    comments=comments
                                    )
            except Exception as e:
                print("Error in storing the memory item", e)    

    def load(self, ckpt_num):
        """
        Announce the vector memory in use, then defer to the base-class ``load``.
        """
        print("Loading the model from vector memory: ", self.vector_memory_database)
        loaded = super().load(ckpt_num)
        return loaded
    
    def save(self, ckpt_num):
        """
        Defer to the base-class ``save`` for the given checkpoint number.
        """
        saved = super().save(ckpt_num)
        return saved

# Manual debug entry point: builds a policy backed by the "test_debug" vector
# memory, feeds it one hand-written observation and forces a single decision.
if __name__ == "__main__":
    policy = DiLUPolicy(vector_memory_database="test_debug")
    policy.reset()
    policy.observe({"observation":"haha", "received_message":"hi"}, 1, False, False, {})
    # act() increments step_count to 151 and 151 % 10 == 1, so this call takes
    # the branch that queries the language model
    policy.step_count = 150
    policy.act()
    # drop into an interactive IPython shell to inspect `policy`
    from IPython import embed; embed()
